# Western Blot Band Detection - Synthetic Dataset Generator

Generate synthetic Western blot images with bounding-box annotations, then train object detection models on them with **PyTorch** (torchvision) and **HuggingFace** Transformers.

## Features

- **Row-aligned bands**: Bands appear in consistent horizontal rows across lanes
- **Row tilt**: Each row can tilt slightly upward or downward
- **Ladder position**: Marker can be first (left) or last (right) lane
- **One annotation per blob**: Overlapping band boxes are merged into a single annotation
- **Irregular loading**: Variable band intensity across lanes
- **Faint bands**: Thin or faint bands are drawn but NOT annotated, so the model learns to ignore them
- **Dual format**: YOLO (.txt) and COCO (.json) annotations
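The "one annotation per blob" rule can be sketched in plain Python. This is not the generator's code: `merge_overlapping_boxes` is a hypothetical helper that assumes boxes are `(x1, y1, x2, y2)` pixel tuples and that "overlapping" means a positive intersection area. Boxes are merged repeatedly, so a chain of overlaps (A touches B, B touches C) collapses into one box.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixels (assumed)


def boxes_overlap(a: Box, b: Box) -> bool:
    """True if the boxes share a positive area; edges that only touch do not count."""
    return a[0] < b[2] and b[0] < a[2] and a[1] < b[3] and b[1] < a[3]


def merge_overlapping_boxes(boxes: List[Box]) -> List[Box]:
    """Replace overlapping pairs with their union until no two boxes overlap."""
    merged = list(boxes)
    changed = True
    while changed:
        changed = False
        for i in range(len(merged)):
            for j in range(i + 1, len(merged)):
                if boxes_overlap(merged[i], merged[j]):
                    a, b = merged[i], merged[j]
                    # Box i becomes the union of both; box j is dropped
                    merged[i] = (min(a[0], b[0]), min(a[1], b[1]),
                                 max(a[2], b[2]), max(a[3], b[3]))
                    del merged[j]
                    changed = True
                    break
            if changed:
                break  # restart the scan, since the list was modified
    return merged


print(merge_overlapping_boxes([(10, 10, 50, 30), (40, 12, 90, 32), (200, 10, 240, 30)]))
# -> [(10, 10, 90, 32), (200, 10, 240, 30)]
```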

## Installation

```bash
pip install torch torchvision
pip install transformers      # For DETR
pip install opencv-python numpy matplotlib pyyaml
pip install pycocotools       # For COCO evaluation
```

## Quick Start

### 1. Preview samples

```bash
python visualize_samples.py --num 6 --output preview.png
```

### 2. Generate dataset

```bash
# Small dataset for testing
python generate_dataset.py --output ./dataset --train 100 --val 20 --test 20

# Full dataset for training
python generate_dataset.py --output ./dataset --train 1000 --val 200 --test 200
```

### 3. Train model

**Option A: Faster R-CNN (torchvision)**
```bash
python train.py --data ./dataset --model fasterrcnn --epochs 50
```

**Option B: RetinaNet (torchvision)**
```bash
python train.py --data ./dataset --model retinanet --epochs 50
```

**Option C: DETR (HuggingFace Transformers)**
```bash
python train.py --data ./dataset --model detr --epochs 50
```

### 4. Run inference

```bash
# With Faster R-CNN
python predict.py --model ./output/best_model.pt --image your_blot.png --type fasterrcnn

# With DETR
python predict.py --model ./output/best_model --image your_blot.png --type detr
```

## Dataset Structure

```
dataset/
├── data.yaml               # YOLO config
├── train/
│   ├── images/             # Training images (.png)
│   ├── labels/             # YOLO annotations (.txt)
│   └── annotations.json    # COCO format (for PyTorch)
├── val/
│   ├── images/
│   ├── labels/
│   └── annotations.json
└── test/
    ├── images/
    ├── labels/
    └── annotations.json
```
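A quick way to check that a generated split is consistent is to compare image and label file stems. This is a minimal sketch, assuming the YOLO convention that `images/0001.png` pairs with `labels/0001.txt`; `check_split` is a hypothetical helper, not one of the scripts listed under Files.

```python
from pathlib import Path
from typing import Dict, List


def check_split(split_dir: str) -> Dict[str, List[str]]:
    """List images without a label file and label files without an image.

    Assumes <split>/images/*.png and <split>/labels/*.txt share file stems.
    """
    root = Path(split_dir)
    image_stems = {p.stem for p in (root / "images").glob("*.png")}
    label_stems = {p.stem for p in (root / "labels").glob("*.txt")}
    return {
        "missing_labels": sorted(image_stems - label_stems),
        "orphan_labels": sorted(label_stems - image_stems),
    }
```

For example, `check_split("./dataset/train")` should return two empty lists for a complete split. If the generator writes no label file for an image with no annotated bands, that image would show up under `missing_labels` without being an error.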

## Annotation Formats

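### YOLO format (labels/*.txt)

One line per annotated band. Center and size are normalized to [0, 1] by the image width and height.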
```
class_id x_center y_center width height
0 0.332500 0.512000 0.177000 0.075000
```
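To draw or evaluate a YOLO label, it has to be converted back to pixel corners. The sketch below uses a hypothetical `yolo_to_xyxy` helper and assumes an 800 × 500 image (the generator's default size in the Customization example).

```python
from typing import Tuple


def yolo_to_xyxy(line: str, img_w: int, img_h: int) -> Tuple[int, Tuple[float, float, float, float]]:
    """Convert one YOLO label line to (class_id, (x1, y1, x2, y2)) in pixels."""
    parts = line.split()
    class_id = int(parts[0])
    xc, yc, w, h = (float(v) for v in parts[1:5])
    # Scale normalized values back to pixels
    xc, w = xc * img_w, w * img_w
    yc, h = yc * img_h, h * img_h
    return class_id, (xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2)


cls, box = yolo_to_xyxy("0 0.332500 0.512000 0.177000 0.075000", 800, 500)
print(cls, [round(v, 2) for v in box])
# -> 0 [195.2, 237.25, 336.8, 274.75]
```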

### COCO format (annotations.json)

As in standard COCO, `bbox` is `[x_min, y_min, width, height]` in absolute pixels.
```json
{
  "images": [{"id": 0, "file_name": "...", "width": 800, "height": 500}],
  "annotations": [{"id": 0, "image_id": 0, "bbox": [x, y, w, h], "category_id": 0}],
  "categories": [{"id": 0, "name": "band"}]
}
```
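torchvision detection models expect per-image targets with `boxes` as `[x1, y1, x2, y2]` and integer `labels`, where label 0 is reserved for background. A minimal sketch of that conversion from the COCO file follows; `coco_to_targets` is hypothetical, and the `+1` label shift is an assumption that matches the 2-class (background + band) predictor used in the inference example, not necessarily what `train.py` does internally.

```python
from collections import defaultdict
from typing import Dict


def coco_to_targets(coco: dict) -> Dict[int, dict]:
    """Group COCO annotations by image_id as torchvision-style targets.

    Images without annotations do not appear in the result.
    """
    targets = defaultdict(lambda: {"boxes": [], "labels": []})
    for ann in coco["annotations"]:
        x, y, w, h = ann["bbox"]  # COCO: top-left corner + size
        target = targets[ann["image_id"]]
        target["boxes"].append([x, y, x + w, y + h])
        # Assumption: shift category_id so that 0 stays free for background
        target["labels"].append(ann["category_id"] + 1)
    return dict(targets)
```

Typical use is `coco_to_targets(json.load(open("dataset/train/annotations.json")))`, then wrapping each entry's lists in `torch.tensor` inside a `Dataset`.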

## Model Comparison

| Model | Framework | Relative speed | Relative accuracy | Best For |
|-------|-----------|-------|----------|----------|
| Faster R-CNN | torchvision | Medium | High | General use |
| RetinaNet | torchvision | Fast | Good | Speed-focused |
| DETR | HuggingFace | Slow | Highest | Best accuracy |

## Customization

### Generator parameters

```python
from western_blot_generator import generate_western_blot_with_annotations

img, annotations = generate_western_blot_with_annotations(
    width=800,
    height=500,
    num_lanes=4,
    include_ladder=True,
    ladder_position='auto',         # 'first', 'last', or 'auto'
    num_protein_rows=(2, 6),        # Min/max protein rows
    bands_per_row_probability=0.7,  # Probability that a lane has a band in a given row
    row_tilt_range=(-15, 15),       # Row tilt in pixels
    skew_range=(-3, 3),             # Global skew (degrees)
    irregular_loading=True,
    add_faint_bands=True,
    min_annotate_intensity=0.4,     # Fainter bands are drawn but not annotated
    min_annotate_height=12,         # Thinner bands (pixels) are not annotated
)
```
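The annotation thresholds and the row tilt can be illustrated with a small sketch. `Band`, `should_annotate` and `row_y` are hypothetical names, not part of `western_blot_generator`; the sketch assumes the thresholds are inclusive lower bounds and that a tilted row shifts linearly from the first lane to the last.

```python
from dataclasses import dataclass


@dataclass
class Band:
    """Hypothetical band record: peak intensity in [0, 1], height in pixels."""
    intensity: float
    height: float


def should_annotate(band: Band, min_intensity: float = 0.4, min_height: float = 12) -> bool:
    """Mimic min_annotate_intensity / min_annotate_height: too faint or too thin -> no box."""
    return band.intensity >= min_intensity and band.height >= min_height


def row_y(base_y: float, tilt: float, lane_index: int, num_lanes: int) -> float:
    """Band center y in one lane when the row drops `tilt` pixels from first to last lane.

    Linear interpolation across lanes is an assumed model of row tilt.
    """
    if num_lanes == 1:
        return base_y
    return base_y + tilt * lane_index / (num_lanes - 1)
```

With `row_tilt_range=(-15, 15)`, a row would draw `tilt` from that interval once and apply it to every lane in the row, which keeps bands row-aligned while letting the row slope up or down.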

## Files

| File | Description |
|------|-------------|
| `western_blot_generator.py` | Core generator with merged annotations |
| `generate_dataset.py` | Generate YOLO + COCO format dataset |
| `train.py` | PyTorch/HuggingFace training script |
| `predict.py` | Inference script |
| `visualize_samples.py` | Preview generated images |

## Training Tips

1. **Start small**: Use 100-200 images to verify that the pipeline works
2. **GPU recommended**: Training is much faster with CUDA
3. **Batch size**: Reduce if you run out of memory (try 2 or 1)
4. **Epochs**: 50-100 epochs are usually sufficient for synthetic data
5. **Fine-tuning**: Consider fine-tuning on real Western blots after training on synthetic data

## Example Usage in Code

```python
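import torch
from PIL import Image
from torchvision import transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the training architecture: Faster R-CNN with a 2-class box
# predictor (background + band). weights=None replaces the deprecated
# pretrained=False argument.
model = fasterrcnn_resnet50_fpn(weights=None)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

# Load trained weights; map_location lets a GPU-trained checkpoint load on CPU
model.load_state_dict(torch.load('output/best_model.pt', map_location=device))
model.to(device)
model.eval()

# Run inference: detection models take a list of [C, H, W] float tensors in [0, 1]
image = Image.open('your_blot.png').convert('RGB')
image_tensor = T.ToTensor()(image).to(device)

with torch.no_grad():
    predictions = model([image_tensor])[0]

# Keep only confident detections
conf_threshold = 0.5
mask = predictions['scores'] > conf_threshold
boxes = predictions['boxes'][mask]    # [N, 4] as (x1, y1, x2, y2) in pixels
scores = predictions['scores'][mask]

print(f"Detected {len(boxes)} bands")
```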
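Before running a full COCO evaluation with pycocotools, a quick sanity check on a few test images is to measure how many ground-truth bands the filtered detections recover. The sketch below is a simple greedy IoU matcher, not COCO mAP; `iou` and `recall_at_iou` are hypothetical helpers, and ground-truth boxes are assumed to be `(x1, y1, x2, y2)` pixel coordinates (for example converted from `labels/*.txt`).

```python
from typing import List, Sequence

Box = Sequence[float]  # (x1, y1, x2, y2) in pixels


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def recall_at_iou(pred: List[Box], gt: List[Box], thr: float = 0.5) -> float:
    """Fraction of ground-truth boxes matched one-to-one by a prediction with IoU >= thr.

    Pairs are matched greedily, highest IoU first.
    """
    if not gt:
        return 1.0
    pairs = sorted(
        ((iou(p, g), i, j) for i, p in enumerate(pred) for j, g in enumerate(gt)),
        reverse=True,
    )
    used_pred, used_gt = set(), set()
    for score, i, j in pairs:
        if score < thr:
            break  # remaining pairs overlap even less
        if i not in used_pred and j not in used_gt:
            used_pred.add(i)
            used_gt.add(j)
    return len(used_gt) / len(gt)
```

With the example above, `recall_at_iou(boxes.tolist(), gt_boxes)` returns 1.0 when every labelled band has a matching detection.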
